"""
Phase 2: Cross-Sectional Predictors of Threshold Crossing
Development Threshold Paper

Logit/probit models for P(crossing), variable importance, speed regressions.
Key test: demographics and income are separate state-space dimensions (Paper 18).
"""

import pandas as pd
import numpy as np
from scipy import stats
import statsmodels.api as sm
import sys
sys.path.insert(0, '/mnt/c/demographics_capital_flows/multilateral/src')

# ═══════════════════════════════════════════════════════════════════════
# Load crossing data from Phase 1
# ═══════════════════════════════════════════════════════════════════════
cross_df = pd.read_csv(
    '/mnt/c/demographics_capital_flows/development_threshold/data/crossing_data.csv'
)

# Income band of the "zone" (compared against gdp_pc_ppp, i.e. PPP GDP per capita).
LOWER = 9000
UPPER = 25000

# Keep zone entrants only: rows with a recorded `cross_9k` that are not flagged
# 'Always above'. Countries still below with no zone entry drop out via the
# missing `cross_9k`.
has_zone_entry = cross_df['cross_9k'].notna()
not_always_above = cross_df['status'] != 'Always above'
zone = cross_df[has_zone_entry & not_always_above].copy()

n_crossed = zone['exited_above'].sum()
n_not_crossed = (zone['exited_above'] == 0).sum()
print(f"Zone entrants: {len(zone)}")
print(f"  Crossed: {n_crossed}")
print(f"  Did not cross: {n_not_crossed}")

# ═══════════════════════════════════════════════════════════════════════
# 2a. Bivariate comparisons with Wilcoxon rank-sum
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2a. BIVARIATE COMPARISONS: CROSSERS vs NON-CROSSERS")
print("=" * 70)

# Entry-year characteristics compared between crossers and non-crossers.
test_vars = ['Z_1_entry', 'old_dep_entry', 'youth_dep_entry', 'working_age_share_entry',
             'kaopen_entry', 'ca_gdp_entry', 'gross_savings_gdp_entry', 'fiscal_bal_gdp_entry',
             'resource_rents_gdp_entry', 'rgdp_growth_entry', 'gdp_pc_ppp_entry',
             'trade_openness_entry', 'life_expectancy_entry', 'human_capital_entry',
             'gross_investment_gdp_entry']


def _stars(p):
    """Return '***', '**', '*' for p below 0.01, 0.05, 0.10; '' otherwise (incl. NaN)."""
    if p < 0.01:
        return '***'
    if p < 0.05:
        return '**'
    if p < 0.1:
        return '*'
    return ''


crosser_rows = zone['exited_above'] == 1
noncrosser_rows = zone['exited_above'] == 0

bivariate_results = []
print(f"\n{'Variable':30s} {'Crossers':>10s} {'Non-cross':>10s} {'Diff':>8s} {'t p-val':>8s} {'W p-val':>8s}")
print("-" * 80)

for col in test_vars:
    if col not in zone.columns:
        continue
    grp_c = zone.loc[crosser_rows, col].dropna()
    grp_nc = zone.loc[noncrosser_rows, col].dropna()
    # Need at least 3 observations on each side for either test.
    if min(len(grp_c), len(grp_nc)) < 3:
        continue

    _, t_p = stats.ttest_ind(grp_c, grp_nc, equal_var=False)              # Welch t-test
    _, w_p = stats.mannwhitneyu(grp_c, grp_nc, alternative='two-sided')   # rank-sum test
    mean_c = grp_c.mean()
    mean_nc = grp_nc.mean()
    diff = mean_c - mean_nc
    short = col.replace('_entry', '')

    # NOTE: the stars reflect the smaller of the two p-values.
    print(f"  {short:28s} {mean_c:10.3f} {mean_nc:10.3f} {diff:+8.3f} {t_p:8.4f} {w_p:8.4f} {_stars(min(t_p, w_p))}")

    bivariate_results.append({
        'variable': short,
        'crossers_mean': mean_c, 'non_crossers_mean': mean_nc,
        'difference': diff, 't_pvalue': t_p, 'wilcoxon_pvalue': w_p,
        'n_crossers': len(grp_c), 'n_non_crossers': len(grp_nc),
    })

pd.DataFrame(bivariate_results).to_csv(
    '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table4_bivariate.csv', index=False)

# ═══════════════════════════════════════════════════════════════════════
# 2b. Logit models: P(crossing) = f(entry characteristics)
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2b. LOGIT MODELS: P(CROSSING)")
print("=" * 70)

def run_logit(df, y_var, x_vars, label):
    """Fit a logit of ``y_var`` on ``x_vars`` (plus a constant) and print a summary.

    Rows with a missing value in ``y_var`` or any of ``x_vars`` are dropped
    before fitting (listwise deletion).

    Parameters
    ----------
    df : pandas.DataFrame
        Source data; must contain ``y_var`` and every name in ``x_vars``.
    y_var : str
        Name of the binary outcome column.
    x_vars : list of str
        Regressor column names (the constant is added here).
    label : str
        Model name used in the printed output and the returned dict.

    Returns
    -------
    dict or None
        Keys ``label``, ``n``, ``pseudo_r2``, ``aic``, ``bic``, ``params`` and
        ``pvalues`` (dicts keyed by column name), ``mfx`` (marginal effect at
        the regressor means, keyed by the names in ``x_vars``) and the fitted
        ``model``. ``None`` when fewer than 20 complete rows remain or the
        fit raises.
    """
    subset = df[[y_var] + x_vars].dropna()
    if len(subset) < 20:
        print(f"  {label}: insufficient observations ({len(subset)})")
        return None

    y = subset[y_var]
    X = sm.add_constant(subset[x_vars])

    try:
        model = sm.Logit(y, X).fit(disp=0)
    except Exception as e:
        print(f"  {label}: failed — {e}")
        return None

    # Marginal effect at the means (MEM). For a logit,
    #   dP/dx_k = beta_k * p * (1 - p),  with p = Lambda(xbar' beta),
    # i.e. the logistic density evaluated at the average regressor vector
    # (X.mean() includes the constant column, so the dot product is the full
    # linear index). Averaging the *predicted probabilities* instead would just
    # give the sample outcome share for a logit with a constant.
    xb_at_means = float(X.mean().dot(model.params))
    p_at_means = 1.0 / (1.0 + np.exp(-xb_at_means))
    density_at_means = p_at_means * (1.0 - p_at_means)
    mfx = {var: model.params[var] * density_at_means for var in x_vars}

    print(f"\n  {label} (n={len(subset)}, pseudo-R²={model.prsquared:.3f}):")
    for var in x_vars:
        coef = model.params[var]
        pval = model.pvalues[var]
        sig = '***' if pval < 0.01 else '**' if pval < 0.05 else '*' if pval < 0.1 else ''
        print(f"    {var:30s}  coef={coef:8.4f}  p={pval:.4f}{sig:3s}  mfx={mfx[var]:.4f}")

    return {
        'label': label,
        'n': len(subset),
        'pseudo_r2': model.prsquared,
        'aic': model.aic,
        'bic': model.bic,
        'params': model.params.to_dict(),
        'pvalues': model.pvalues.to_dict(),
        'mfx': mfx,
        'model': model,
    }

# P(crossing) specifications, fitted in order M1..M7 on the zone-entrant
# sample with 'exited_above' as the outcome. Regressor order fixes the order of
# printed coefficients and of the saved parameter columns.
_core = ['Z_1_entry', 'gdp_pc_ppp_entry']                         # demographics + income
_openness = ['kaopen_entry', 'trade_openness_entry']              # capital/trade openness
_savings_fiscal = ['gross_savings_gdp_entry', 'fiscal_bal_gdp_entry']

logit_specs = [
    ('M1: Demographics only', ['Z_1_entry']),
    ('M2: Demographics + income', list(_core)),
    ('M3: Income only', ['gdp_pc_ppp_entry']),
    ('M4: Demographics + income + institutions', _core + _openness),
    ('M5: Full model', _core + _openness + _savings_fiscal),
    ('M6: Full + resources', _core + _openness + _savings_fiscal + ['resource_rents_gdp_entry']),
    ('M7: Kitchen sink', _core + ['old_dep_entry', 'youth_dep_entry', 'kaopen_entry',
                                  'gross_savings_gdp_entry', 'resource_rents_gdp_entry',
                                  'rgdp_growth_entry']),
]

m1, m2, m3, m4, m5, m6, m7 = [
    run_logit(zone, 'exited_above', regressors, spec_label)
    for spec_label, regressors in logit_specs
]

# ═══════════════════════════════════════════════════════════════════════
# 2c. Variable importance: pseudo-R² contribution
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2c. VARIABLE IMPORTANCE (PSEUDO-R² CONTRIBUTION)")
print("=" * 70)

# Use parsimonious set to avoid perfect separation (small cross-section)
base_vars = ['Z_1_entry', 'gdp_pc_ppp_entry', 'kaopen_entry', 'resource_rents_gdp_entry',
             'rgdp_growth_entry']

# Single-variable pseudo-R² (most robust: no multicollinearity)
all_importance_vars = ['Z_1_entry', 'gdp_pc_ppp_entry', 'kaopen_entry', 'gross_savings_gdp_entry',
                       'fiscal_bal_gdp_entry', 'resource_rents_gdp_entry', 'rgdp_growth_entry',
                       'trade_openness_entry', 'life_expectancy_entry', 'human_capital_entry']

importance = []
print("\nSingle-variable pseudo-R² (univariate logit):")
for var in all_importance_vars:
    if var not in zone.columns:
        continue
    name = var.replace('_entry', '')
    single_sample = zone[['exited_above', var]].dropna()
    n_cross = (single_sample['exited_above'] == 1).sum()
    n_noncross = (single_sample['exited_above'] == 0).sum()
    # Both outcome classes need a few members, otherwise the logit is degenerate.
    if len(single_sample) < 15 or n_cross < 3 or n_noncross < 3:
        continue
    try:
        single_model = sm.Logit(single_sample['exited_above'],
                                sm.add_constant(single_sample[[var]])).fit(disp=0, maxiter=100)
    except Exception as exc:
        print(f"  {name:25s}: failed ({exc})")
        continue
    coef = single_model.params[var]
    pval = single_model.pvalues[var]
    sig = '***' if pval < 0.01 else '**' if pval < 0.05 else '*' if pval < 0.1 else ''
    print(f"  {name:25s}: R²={single_model.prsquared:.4f}  coef={coef:.4f}  p={pval:.4f}{sig}")
    importance.append({'variable': name, 'single_r2': single_model.prsquared,
                       'coef': coef, 'pvalue': pval})

# Drop-one importance using parsimonious model
common = zone[['exited_above'] + base_vars].dropna()
print(f"\nDrop-one importance (parsimonious model, n={len(common)}):")

if len(common) < 20:
    print("  Too few complete observations for the drop-one analysis")
else:
    y = common['exited_above']
    try:
        full_model = sm.Logit(y, sm.add_constant(common[base_vars])).fit(disp=0, maxiter=100)
    except Exception as exc:
        print(f"  Full parsimonious model failed ({exc})")
    else:
        full_pr2 = full_model.prsquared
        print(f"Full model pseudo-R²: {full_pr2:.4f}")

        # Attach drop-one results to the univariate rows; create a row when the
        # univariate fit was skipped so the contribution is not lost.
        by_name = {imp['variable']: imp for imp in importance}
        for var in base_vars:
            name = var.replace('_entry', '')
            remaining = [v for v in base_vars if v != var]
            try:
                reduced_model = sm.Logit(y, sm.add_constant(common[remaining])).fit(disp=0, maxiter=100)
            except Exception as exc:
                print(f"  Drop {name:25s}: reduced model failed ({exc})")
                continue
            contribution = full_pr2 - reduced_model.prsquared
            # Guard against a zero full-model pseudo-R² (share undefined).
            share = contribution / full_pr2 * 100 if full_pr2 else float('nan')
            print(f"  Drop {name:25s}: ΔR²={contribution:+.4f} ({share:.1f}%)")
            if name not in by_name:
                by_name[name] = {'variable': name}
                importance.append(by_name[name])
            by_name[name]['drop_one_contribution'] = contribution

pd.DataFrame(importance).to_csv(
    '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table5_variable_importance.csv', index=False)

# ═══════════════════════════════════════════════════════════════════════
# 2d. Nested models: demographics → institutions → resources → macro
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2d. NESTED MODELS: INCREMENTAL PSEUDO-R²")
print("=" * 70)

nesting = [
    ('Income only', ['gdp_pc_ppp_entry']),
    ('+ Demographics', ['gdp_pc_ppp_entry', 'Z_1_entry']),
    ('+ Demographics (detail)', ['gdp_pc_ppp_entry', 'Z_1_entry', 'old_dep_entry', 'youth_dep_entry']),
    ('+ Institutions', ['gdp_pc_ppp_entry', 'Z_1_entry', 'kaopen_entry', 'trade_openness_entry']),
    ('+ Savings/Fiscal', ['gdp_pc_ppp_entry', 'Z_1_entry', 'kaopen_entry', 'trade_openness_entry',
                          'gross_savings_gdp_entry', 'fiscal_bal_gdp_entry']),
    ('+ Resources', ['gdp_pc_ppp_entry', 'Z_1_entry', 'kaopen_entry', 'trade_openness_entry',
                     'gross_savings_gdp_entry', 'fiscal_bal_gdp_entry', 'resource_rents_gdp_entry']),
    ('+ Growth', ['gdp_pc_ppp_entry', 'Z_1_entry', 'kaopen_entry', 'trade_openness_entry',
                  'gross_savings_gdp_entry', 'fiscal_bal_gdp_entry', 'resource_rents_gdp_entry',
                  'rgdp_growth_entry']),
]

# Use common sample for fair comparison
all_vars = ['exited_above', 'gdp_pc_ppp_entry', 'Z_1_entry', 'old_dep_entry', 'youth_dep_entry',
            'kaopen_entry', 'trade_openness_entry', 'gross_savings_gdp_entry', 'fiscal_bal_gdp_entry',
            'resource_rents_gdp_entry', 'rgdp_growth_entry']
common_nest = zone[[v for v in all_vars if v in zone.columns]].dropna()
print(f"Common sample for nesting: {len(common_nest)}")

nested_results = []
fitted_so_far = []  # (label, regressor set, pseudo-R²) of models that fitted
y = common_nest['exited_above']
for label, xvars in nesting:
    available = [v for v in xvars if v in common_nest.columns]
    if not available:
        continue
    X = sm.add_constant(common_nest[available])
    try:
        model = sm.Logit(y, X).fit(disp=0, maxiter=100)
    except Exception as exc:
        print(f"  {label:30s}  failed — {exc}")
        continue

    # The ladder is not strictly nested: '+ Demographics (detail)' is a side
    # branch whose age-dependency terms are not carried into '+ Institutions'.
    # So measure each step against the latest fitted model whose regressors
    # are a subset of this one's. If there is none, use the null model
    # (McFadden pseudo-R² = 0).
    regressors = set(available)
    baseline_label, baseline_r2 = 'null', 0.0
    for prev_label, prev_regressors, prev_r2 in reversed(fitted_so_far):
        if prev_regressors <= regressors:
            baseline_label, baseline_r2 = prev_label, prev_r2
            break
    incr = model.prsquared - baseline_r2

    print(f"  {label:30s}  pseudo-R²={model.prsquared:.4f}  incremental={incr:+.4f} "
          f"(vs {baseline_label})  AIC={model.aic:.1f}")
    nested_results.append({
        'model': label, 'pseudo_r2': model.prsquared, 'incremental_r2': incr,
        'aic': model.aic, 'n': len(common_nest), 'baseline': baseline_label,
    })
    fitted_so_far.append((label, regressors, model.prsquared))

pd.DataFrame(nested_results).to_csv(
    '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table6_nested_models.csv', index=False)

# ═══════════════════════════════════════════════════════════════════════
# 2e. Speed of crossing: OLS on years_in_zone
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2e. SPEED OF CROSSING (OLS: years_in_zone)")
print("=" * 70)

# Transit-time regression is run on crossers only.
crossers = zone[zone['exited_above'] == 1].copy()
print(f"\nCrossers with transit time: {crossers['years_in_zone'].notna().sum()}")

speed_vars = ['Z_1_entry', 'gdp_pc_ppp_entry', 'kaopen_entry', 'gross_savings_gdp_entry',
              'resource_rents_gdp_entry', 'rgdp_growth_entry']
available = [v for v in speed_vars if v in crossers.columns]
speed_sample = crossers[['years_in_zone'] + available].dropna()

if len(speed_sample) < 10:
    print(f"  Insufficient observations: {len(speed_sample)}")
else:
    y = speed_sample['years_in_zone']
    X = sm.add_constant(speed_sample[available])
    ols = sm.OLS(y, X).fit()
    print(ols.summary2().tables[1])

    # One row per regressor, including the constant.
    speed_results = pd.DataFrame({
        'variable': ols.params.index,
        'coefficient': ols.params.values,
        'std_error': ols.bse.values,
        'p_value': ols.pvalues.values,
    })
    speed_results.to_csv(
        '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table7_speed_ols.csv', index=False)

# ═══════════════════════════════════════════════════════════════════════
# 2f. Falling back: logit P(fell below) among zone entrants
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2f. FALLING BACK: P(fell below $9k)")
print("=" * 70)

# Load full panel to check who fell back
full = pd.read_csv('/mnt/c/demographics_capital_flows/multilateral/140_country/data/processed/full_panel.csv')
fpanel = full[full['year'] <= 2024].copy()

# Split the observed GDP series by country once, instead of re-scanning the
# whole panel for every zone entrant.
gdp_obs = fpanel.loc[fpanel['gdp_pc_ppp'].notna(), ['iso3', 'year', 'gdp_pc_ppp']]
gdp_by_iso = {iso: grp for iso, grp in gdp_obs.groupby('iso3')}

# fell_back = 1 if gdp_pc_ppp is below LOWER in any panel year strictly after
# the entrant's `cross_9k` year; 0 otherwise (also when there is no entry year
# or no later observation).
fell_back_flags = []
for _, row in zone.iterrows():
    entry_year = row['cross_9k']
    series = gdp_by_iso.get(row['iso3'])
    if pd.isna(entry_year) or series is None:
        fell_back_flags.append(0)
        continue
    later_gdp = series.loc[series['year'] > entry_year, 'gdp_pc_ppp']
    fell_back_flags.append(int((later_gdp < LOWER).any()))
zone['fell_back'] = fell_back_flags

fell_back_n = int(zone['fell_back'].sum())
print(f"Countries that fell back below $9k after entering zone: {fell_back_n}")

if fell_back_n < 5:
    print(f"  Too few fall-backs for a logit ({fell_back_n} < 5)")
else:
    fb_vars = ['Z_1_entry', 'gdp_pc_ppp_entry', 'kaopen_entry', 'resource_rents_gdp_entry',
               'rgdp_growth_entry', 'gross_savings_gdp_entry']
    available = [v for v in fb_vars if v in zone.columns]
    fb_sample = zone[['fell_back'] + available].dropna()
    n_fell = int(fb_sample['fell_back'].sum())
    n_stayed = int((fb_sample['fell_back'] == 0).sum())

    # Both outcome classes need a few members, otherwise the logit is degenerate.
    if len(fb_sample) >= 15 and n_fell >= 3 and n_stayed >= 3:
        y = fb_sample['fell_back']
        X = sm.add_constant(fb_sample[available])
        try:
            fb_model = sm.Logit(y, X).fit(disp=0, maxiter=100)
        except Exception as exc:
            print(f"  Logit failed (likely separation): {exc}")
        else:
            print(f"\n  Logit P(fell back), n={len(fb_sample)}, pseudo-R²={fb_model.prsquared:.3f}")
            for var in available:
                coef = fb_model.params[var]
                pval = fb_model.pvalues[var]
                sig = '***' if pval < 0.01 else '**' if pval < 0.05 else '*' if pval < 0.1 else ''
                print(f"    {var.replace('_entry',''):25s}  coef={coef:8.4f}  p={pval:.4f}{sig}")
    else:
        print(f"  Insufficient variation for logit (n={len(fb_sample)}, fell_back={n_fell})")

# ═══════════════════════════════════════════════════════════════════════
# 2g. Within-band analysis: does Z₁ predict within income sub-bands?
# ═══════════════════════════════════════════════════════════════════════
print("\n" + "=" * 70)
print("2g. WITHIN-BAND ANALYSIS: Z₁ PREDICTIVE POWER BY ENTRY GDP LEVEL")
print("=" * 70)

# Split zone entrants with both entry GDP and entry Z₁ observed into
# entry-GDP terciles, then compare Z₁ of crossers vs non-crossers in each.
tercile_labels = ['Low ($9k-mid)', 'Mid', 'High (mid-$25k)']
has_gdp_and_z1 = zone['gdp_pc_ppp_entry'].notna() & zone['Z_1_entry'].notna()
zone_with_gdp = zone[has_gdp_and_z1].copy()
zone_with_gdp['gdp_tercile'] = pd.qcut(zone_with_gdp['gdp_pc_ppp_entry'], 3, labels=tercile_labels)

for tercile in tercile_labels:
    band = zone_with_gdp[zone_with_gdp['gdp_tercile'] == tercile]
    n_total = len(band)
    n_cross = band['exited_above'].sum()

    z1_cross = band.loc[band['exited_above'] == 1, 'Z_1_entry'].dropna()
    z1_noncross = band.loc[band['exited_above'] == 0, 'Z_1_entry'].dropna()

    if len(z1_cross) < 2 or len(z1_noncross) < 2:
        print(f"  {tercile}: n={n_total}, crossed={n_cross} — insufficient for t-test")
        continue

    _, p = stats.ttest_ind(z1_cross, z1_noncross, equal_var=False)  # Welch t-test
    if p < 0.01:
        sig = '***'
    elif p < 0.05:
        sig = '**'
    elif p < 0.1:
        sig = '*'
    else:
        sig = ''
    entry_gdp = band['gdp_pc_ppp_entry']
    gdp_range = f"${entry_gdp.min():,.0f}-${entry_gdp.max():,.0f}"
    print(f"  {tercile:20s} ({gdp_range:20s}): n={n_total}, crossed={n_cross}, "
          f"Z₁ crossers={z1_cross.mean():.3f}, Z₁ non-crossers={z1_noncross.mean():.3f}, p={p:.4f}{sig}")

# ═══════════════════════════════════════════════════════════════════════
# Save logit results table
# ═══════════════════════════════════════════════════════════════════════
# Wide summary of M1–M7: one row per model that fitted (run_logit returned a
# dict), with coef_<name> / pval_<name> columns for every fitted parameter.
table_labels = ['M1: Demographics only', 'M2: Demographics + income',
                'M3: Income only', 'M4: + institutions',
                'M5: Full', 'M6: + resources', 'M7: Kitchen sink']

logit_summary = []
for table_label, result in zip(table_labels, (m1, m2, m3, m4, m5, m6, m7)):
    if result is None:
        continue
    summary_row = {'model': table_label}
    summary_row.update((key, result[key]) for key in ('n', 'pseudo_r2', 'aic', 'bic'))
    for prefix, source in (('coef', 'params'), ('pval', 'pvalues')):
        summary_row.update((f'{prefix}_{name}', value) for name, value in result[source].items())
    logit_summary.append(summary_row)

pd.DataFrame(logit_summary).to_csv(
    '/mnt/c/demographics_capital_flows/development_threshold/output/tables/table8_logit_models.csv', index=False)

print("\n✓ Phase 2 complete.")
